import torch
import torch.nn.functional as F
from torch import nn


def Reconstruction_loss(fea1, fea2):
    """Return the mean-squared error between two feature tensors.

    Thin wrapper over ``F.mse_loss`` with ``reduction="mean"`` (its default):
    the squared element-wise difference is averaged over every element.

    Args:
        fea1: first tensor, passed to ``F.mse_loss`` as ``input``.
        fea2: second tensor, passed as ``target``; normally the same shape
            as ``fea1``.

    Returns:
        0-dim tensor holding the mean squared error.
    """
    return F.mse_loss(input=fea1, target=fea2, reduction="mean")


# fea1 = torch.randn(20, 90, 45)
# fea2 = torch.randn(20, 90, 45)
# rec_loss = Reconstruction_loss(fea1, fea2)
# print(rec_loss)


def CS_loss(fea1, fea2):
    """Return the mean cosine similarity of two tensors compared along dim 2.

    Cosine similarity is computed between ``fea1`` and ``fea2`` along
    dimension 2 (default ``eps``), then averaged over all remaining
    positions. Note that this returns the similarity itself (higher means
    more similar), not ``1 - similarity``.

    Args:
        fea1: tensor with at least 3 dimensions.
        fea2: tensor broadcast-compatible with ``fea1``.

    Returns:
        0-dim tensor with the averaged cosine similarity.
    """
    per_position = F.cosine_similarity(fea1, fea2, dim=2)
    return per_position.mean()

# fea1 = torch.randn(20, 90, 16)
# fea2 = torch.randn(20, 90, 16)
# loss_ss = CS_loss(fea1, fea2)
# print(loss_ss)


def SH_weight_CS(tensor1, tensor2, hind):
    """Broadcast cosine similarity between two per-sample-flattened tensors.

    Each sample of ``tensor1`` / ``tensor2`` is flattened to a vector (lengths
    N1 / N2). The vectors are broadcast against each other to shape
    (batch, N1, N2) and the cosine similarity is taken over dim 1, leaving one
    value per element of ``tensor2``. That result is reshaped to
    ``(batch, -1, hind)``; for 90 * hind elements per sample this is
    ``(batch, 90, hind)``.

    NOTE(review): because of the broadcast, entry k reduces (up to the eps
    clamp) to ``sign(t2[b, k]) * sum(t1[b]) / (||t1[b]|| * sqrt(N1))``, so it
    depends on ``tensor2`` only through the sign of each element. This looks
    unintended; confirm the intended similarity before relying on it. The
    broadcast also materializes (batch, N1, N2) intermediates, which is
    memory-heavy for large inputs.

    Args:
        tensor1: tensor whose first dimension is the batch.
        tensor2: tensor with the same batch size; its per-sample element
            count must be divisible by ``hind``.
        hind: size of the last output dimension.

    Returns:
        Tensor of shape ``(batch, N2 // hind, hind)``.
    """
    batch_size = tensor1.shape[0]
    # reshape (not view) so non-contiguous inputs are accepted as well.
    flat1 = tensor1.reshape(batch_size, -1)
    flat2 = tensor2.reshape(batch_size, -1)
    sim = F.cosine_similarity(flat1.unsqueeze(2), flat2.unsqueeze(1), dim=1)
    # Infer the middle dimension instead of hard-coding 90 nodes.
    return sim.reshape(batch_size, -1, hind)


def S_weights(nets, hind):
    """Average ``SH_weight_CS`` over every unordered pair of slices along dim 1.

    Args:
        nets: tensor indexed as ``nets[:, t]``; presumably
            (batch, T, nodes, hind) -- TODO confirm with callers. Each slice
            ``nets[:, t]`` is passed to ``SH_weight_CS``.
        hind: forwarded to ``SH_weight_CS`` as the last output dimension.

    Returns:
        Mean of ``SH_weight_CS(nets[:, i], nets[:, j], hind)`` over all pairs
        ``i < j``, on the same device as ``nets``.

    Raises:
        ValueError: if ``nets.shape[1] < 2`` (there are no pairs to average).
    """
    T = nets.shape[1]
    if T < 2:
        raise ValueError(f"S_weights needs at least 2 slices along dim 1, got {T}")
    num_pairs = T * (T - 1) // 2
    # Summing the pairwise terms directly (rather than into a zeros buffer
    # created with .cuda()) keeps the result on nets' own device and avoids
    # hard-coding the node count.
    total = sum(
        SH_weight_CS(nets[:, i], nets[:, j], hind)
        for i in range(T - 1)
        for j in range(i + 1, T)
    )
    return total / num_pairs


def T_weights(nets, hind):
    """Average ``SH_weight_CS - 1`` over every unordered pair of slices along dim 1.

    Mathematically this is the pairwise-mean similarity shifted down by one.
    Cosine similarity is at most 1, so entries are <= 0 up to rounding.

    Args:
        nets: tensor indexed as ``nets[:, t]``; presumably
            (batch, T, nodes, hind) -- TODO confirm with callers. Each slice
            ``nets[:, t]`` is passed to ``SH_weight_CS``.
        hind: forwarded to ``SH_weight_CS`` as the last output dimension.

    Returns:
        Mean of ``SH_weight_CS(nets[:, i], nets[:, j], hind) - 1`` over all
        pairs ``i < j``, on the same device as ``nets``.

    Raises:
        ValueError: if ``nets.shape[1] < 2`` (there are no pairs to average).
    """
    T = nets.shape[1]
    if T < 2:
        raise ValueError(f"T_weights needs at least 2 slices along dim 1, got {T}")
    num_pairs = T * (T - 1) // 2
    # Summing the pairwise terms directly (rather than into a zeros buffer
    # created with .cuda()) keeps the result on nets' own device and avoids
    # hard-coding the node count.
    total = sum(
        SH_weight_CS(nets[:, i], nets[:, j], hind) - 1
        for i in range(T - 1)
        for j in range(i + 1, T)
    )
    return total / num_pairs


# nets = torch.randn(20, 3, 90, 45).cuda()
# out = S_weights(nets, 45)
# print(out)


